import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import torch
from models.EEGNet import EEGNet


def main(*, batch=2, channels=32, datapoints=128, num_classes=3):
    """Smoke-test EEGNet by running one forward pass on random input.

    Builds an ``EEGNet`` with the given sizes, feeds it a random tensor and
    prints the shape of the output. No assertion is made about that shape;
    the printed value must be inspected by whoever runs the script.

    Args:
        batch: Number of samples in the random input batch.
        channels: Number of electrodes; passed to EEGNet as ``num_electrodes``
            and used as the second dimension of the input tensor.
        datapoints: Number of samples per channel; passed to EEGNet and used
            as the last dimension of the input tensor.
        num_classes: Number of output classes passed to EEGNet.
    """
    model = EEGNet(num_electrodes=channels, datapoints=datapoints, num_classes=num_classes)
    # Input is (batch, channels, datapoints). Whether EEGNet adds its own
    # singleton "depth" dimension internally is not visible here -- TODO confirm.
    x = torch.randn(batch, channels, datapoints)
    # Only the output shape is inspected, so don't build an autograd graph.
    with torch.no_grad():
        y = model(x)
    print("OK - forward pass shape:", y.shape)


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
